# coding: utf-8
import timeit
import wandb
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch
from utils import accuracy

from timm.data.loader import create_loader_sampler
from timm.data import Mixup, FastCollateMixup
from timm import utils
from timm.loss.cross_entropy import SoftTargetCrossEntropy, LabelSmoothingCrossEntropy
from timm.scheduler.scheduler_factory import create_scheduler_v2, scheduler_kwargs
from timm.optim.optim_factory import create_optimizer_v2, optimizer_kwargs

import logging

from PIL import ImageFile 
ImageFile.LOAD_TRUNCATED_IMAGES = True 

from timm.data.loader import create_loader_sampler
from timm.data import resolve_data_config

class MainNN(object):
    """Training/evaluation driver for an image classifier.

    The constructor only stores the run configuration. ``run_main`` builds the
    model, the timm data loaders, the loss functions, the optimizer and the
    learning-rate scheduler, then runs the train/eval loop and reports results
    to stdout, to Weights & Biases and to a marker file.
    """

    def __init__(self, loop, n_data, num_epochs, batch_size_training, batch_size_test,
                 n_model, flag_nccl, rank, n_gpu, device, flag_wandb, flag_defaug,
                 flag_acc5, flag_lr_schedule, path,
                 patch_size, dim, dim_token, permute, expansion_factor, L, args):
        """Store the run configuration; no heavy work is done here.

        Args:
            loop: Index of the repetition; added to ``args.seed`` to form the seed.
            n_data: Dataset name (used in the result file name).
            num_epochs: Number of epochs the train/eval loop runs.
            batch_size_training: Batch size passed to the training loader.
            batch_size_test: Stored test batch size.
            n_model: Model name, e.g. ``'ResNet18'``, ``'ResNet50'``,
                ``'ResNet101'`` or one of the MLP-Mixer spellings.
            flag_nccl: ``1`` for a distributed (DDP) run.
            rank: Rank of this process; rank 0 is the one that logs.
            n_gpu: Number of replicas given to the distributed samplers.
            device: Device used for the model and tensors in distributed runs.
            flag_wandb: ``1`` to log to Weights & Biases.
            flag_defaug: Default-augmentation flag (used in the result file name).
            flag_acc5: ``1`` to also track top-5 accuracy.
            flag_lr_schedule: ``1`` to create a timm LR scheduler.
            path: Dataset root containing ``train`` and ``val`` folders.
            patch_size, dim, dim_token, permute, expansion_factor, L:
                MLP-Mixer hyper-parameters forwarded to the model constructor
                (``L`` is passed as ``depth``).
            args: Parsed argument namespace; ``seed`` and ``dim_ppfc`` are read
                here, the rest in ``run_main``.
        """
        self.seed = args.seed + loop
        self.train_loader = None
        self.eval_loader = None
        self.n_data = n_data  # dataset
        self.num_classes = 1000  # for imagenet; overwritten from args in run_main
        self.num_channel = 0
        self.num_epochs = num_epochs
        self.batch_size_training = batch_size_training
        self.batch_size_test = batch_size_test
        self.n_model = n_model
        self.loss_training_batch = None  # minibatch loss
        self.flag_wandb = flag_wandb  # weights and biases
        self.flag_nccl = flag_nccl
        self.rank = rank
        self.n_gpu = n_gpu
        self.device = device
        self.flag_defaug = flag_defaug  # Default augmentation (cifar, svhn)
        self.flag_acc5 = flag_acc5
        self.flag_lr_schedule = flag_lr_schedule
        self.path = path
        # MLP-Mixer hyper-parameters
        self.patch_size = patch_size
        self.dim = dim
        self.dim_token = dim_token
        self.permute = permute
        self.expansion_factor = expansion_factor
        self.L = L
        self.dim_ppfc = args.dim_ppfc

        self.args = args

    def _build_model(self):
        """Instantiate the network selected by ``self.n_model``.

        Raises:
            ValueError: If ``self.n_model`` is not a supported model name.
        """
        args = self.args
        if self.n_model == 'ResNet18':
            return torchvision.models.resnet18()
        if self.n_model == 'ResNet50':
            return torchvision.models.resnet50()
        if self.n_model == 'ResNet101':
            return torchvision.models.resnet101()
        if self.n_model in ["MLPMixer", "mlpmixer", "mlp-mixer", "mlp_mixer"]:
            import sys, os
            # Make the parent directory importable so the bmlp package resolves.
            sys.path.append( os.path.dirname(__file__) + "/../" )
            from bmlp.models import MLPMixer, MLPMixerSep
            if args.fix == 3:
                return MLPMixerSep(image_size=224, channels=3, patch_size=self.patch_size,
                                   dim=self.dim, dim_token=self.dim_token,
                                   depth=self.L, num_classes=1000,
                                   expansion_factor=self.expansion_factor,
                                   expansion_factor_token=self.expansion_factor,
                                   permute_per_blocks=self.permute,
                                   dim_ppfc=self.dim_ppfc)
            return MLPMixer(image_size=224, channels=3, patch_size=self.patch_size,
                            dim=self.dim, dim_token=self.dim_token,
                            depth=self.L, num_classes=1000,
                            expansion_factor=self.expansion_factor,
                            expansion_factor_token=self.expansion_factor,
                            permute_per_blocks=self.permute,
                            remove_ppfc=args.remove_ppfc,)
        raise ValueError('Unsupported model: {}'.format(self.n_model))

    @staticmethod
    def _count_correct(outputs, target):
        """Return how many rows of ``outputs`` have their arg-max equal to the label.

        ``target`` may hold class indices (1-D) or per-class weights (2-D, as
        produced by mixup/cutmix); in the 2-D case the highest-weighted class
        is used as the label.
        """
        predicted = outputs.argmax(dim=1)
        if target.dim() > 1:
            target = target.argmax(dim=1)
        return (predicted == target).sum().item()

    def _reduce_to_rank0(self, value, device):
        """Reduce a per-process scalar onto rank 0 in distributed runs.

        Outside distributed mode the value is returned unchanged. The reduced
        result is only meaningful on rank 0, which is the rank that logs.
        """
        if self.flag_nccl != 1:
            return value
        tensor = torch.tensor(value, device=device)
        dist.reduce(tensor, dst=0)
        return float(tensor)

    def run_main(self):
        """Build everything, run the train/eval loop and report the results.

        Raises:
            ValueError: If the model name is unsupported, or if the run is not
                distributed (``flag_nccl != 1``), which the dataset setup
                does not support.
        """
        import os  # local import: used to create the results directory

        args = self.args

        # Seed NumPy and PyTorch from the per-loop seed.
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        # In distributed runs only rank 0 prints and logs.
        flag_log = 1 if (self.flag_nccl != 1 or self.rank == 0) else 0

        model = self._build_model()

        # Device placement.
        if self.flag_nccl == 1:
            device = self.device
            model = model.to(device)
            model = DDP(model, device_ids=[device])
        else:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            model = model.to(device)
            if device == 'cuda':
                torch.backends.cudnn.benchmark = True
                print('GPU={}'.format(torch.cuda.device_count()))

        # setup mixup / cutmix
        if flag_log:
            print("setup mixup/cutmix...")
        collate_fn = None
        mixup_fn = None
        mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
        if mixup_active:
            mixup_args = dict(
                mixup_alpha=args.mixup,
                cutmix_alpha=args.cutmix,
                cutmix_minmax=args.cutmix_minmax,
                prob=args.mixup_prob,
                switch_prob=args.mixup_switch_prob,
                mode=args.mixup_mode,
                label_smoothing=args.smoothing,
                num_classes=args.num_classes
            )
            # With the prefetcher the mixing happens in the collate function,
            # otherwise it is applied to each batch inside the training loop.
            if args.prefetcher:
                collate_fn = FastCollateMixup(**mixup_args)
            else:
                mixup_fn = Mixup(**mixup_args)

        data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

        # Datasets.
        if flag_log:
            print("setup dataset...")
        if self.flag_nccl != 1:
            raise ValueError('Only distributed runs (flag_nccl == 1) are supported.')
        train_dataset = torchvision.datasets.ImageFolder(root='{}/train'.format(self.path), transform=None)
        eval_dataset = torchvision.datasets.ImageFolder(root='{}/val'.format(self.path), transform=None)
        self.num_channel = args.in_chans
        self.num_classes = args.num_classes

        # Each process sees its own shard of the data.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=self.n_gpu, rank=self.rank)
        eval_sampler = torch.utils.data.distributed.DistributedSampler(
            eval_dataset, num_replicas=self.n_gpu, rank=self.rank)

        train_interpolation = args.train_interpolation
        if args.no_aug or not train_interpolation:
            train_interpolation = data_config['interpolation']
        self.train_loader = create_loader_sampler(
            sampler=train_sampler,
            dataset=train_dataset,
            input_size=data_config['input_size'],
            batch_size=self.batch_size_training,
            is_training=True,
            use_prefetcher=args.prefetcher,
            no_aug=args.no_aug,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            re_split=args.resplit,
            scale=args.scale,
            ratio=args.ratio,
            hflip=args.hflip,
            vflip=args.vflip,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            num_aug_repeats=args.aug_repeats,
            num_aug_splits=0,
            interpolation=train_interpolation,
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            collate_fn=collate_fn,
            pin_memory=args.pin_mem,
            device=device,
            use_multi_epochs_loader=args.use_multi_epochs_loader,
            worker_seeding=args.worker_seeding,
        )

        self.eval_loader = create_loader_sampler(
            sampler=eval_sampler,
            dataset=eval_dataset,
            input_size=data_config['input_size'],
            batch_size=args.validation_batch_size or args.batch_size,
            is_training=False,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            crop_pct=data_config['crop_pct'],
            pin_memory=args.pin_mem,
            device=device,
        )

        # Loss functions: soft targets need a soft-target loss.
        if mixup_active:
            train_loss_fn = SoftTargetCrossEntropy()
        elif args.smoothing:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
        else:
            train_loss_fn = nn.CrossEntropyLoss()
        train_loss_fn = train_loss_fn.to(device=device)
        validate_loss_fn = nn.CrossEntropyLoss().to(device)

        # Derive the learning rate from the base LR and the global batch size
        # when no explicit LR was given.
        if not args.lr:
            global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
            batch_ratio = global_batch_size / args.lr_base_size
            if not args.lr_base_scale:
                on = args.opt.lower()
                args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
            if args.lr_base_scale == 'sqrt':
                batch_ratio = batch_ratio ** 0.5
            args.lr = args.lr_base * batch_ratio
            if utils.is_primary(args):
                print(
                    f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
                    f'and effective global batch size ({global_batch_size}) and world size ({args.world_size}) with {args.lr_base_scale} scaling.')

        optimizer = create_optimizer_v2(
            model,
            **optimizer_kwargs(cfg=args),
            **args.opt_kwargs,
        )

        # Learning-rate schedule. Both names are bound up front so the training
        # loop also works when no scheduler is requested.
        if flag_log:
            print("set up scheduler...")
        lr_scheduler = None
        updates_per_epoch = 0
        if self.flag_lr_schedule == 1:
            # NOTE(review): the loop below steps the optimizer on every batch
            # and never accumulates gradients, so this matches the real number
            # of updates only when grad_accum_steps == 1 -- confirm.
            updates_per_epoch = (len(self.train_loader) + args.grad_accum_steps - 1) // args.grad_accum_steps
            lr_scheduler, num_epochs = create_scheduler_v2(
                optimizer,
                **scheduler_kwargs(args),
                updates_per_epoch=updates_per_epoch,
            )
            start_epoch = args.start_epoch if args.start_epoch is not None else 0
            if lr_scheduler is not None and start_epoch > 0:
                if args.sched_on_updates:
                    lr_scheduler.step_update(start_epoch * updates_per_epoch)
                else:
                    lr_scheduler.step(start_epoch)

            if lr_scheduler is not None and utils.is_primary(args):
                print(
                    f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

        # Per-epoch results; columns are:
        # 0 training loss, 1 test top-1, 2 test loss, 3 epoch time, 4 test top-5.
        results = np.zeros((self.num_epochs, 5))
        start_time = timeit.default_timer()

        for epoch in range(self.num_epochs):

            # Optionally switch mixup off for the remaining epochs.
            if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
                if args.prefetcher and self.train_loader.mixup_enabled:
                    self.train_loader.mixup_enabled = False
                elif mixup_fn is not None:
                    mixup_fn.mixup_enabled = False

            # ---- Training ----
            start_epoch_time = timeit.default_timer()

            if self.flag_nccl == 1:
                # Give the distributed sampler the epoch so shuffling changes.
                self.train_loader.sampler.set_epoch(epoch)

            loss_training_all = 0
            correct_train = 0

            model.train()
            for step, (images, target) in enumerate(self.train_loader, start=1):
                images = images.to(device)
                target = target.to(device)
                if mixup_fn is not None:
                    images, target = mixup_fn(images, target)

                outputs = model(images)
                loss_training = train_loss_fn(outputs, target)
                optimizer.zero_grad()
                loss_training.backward()  # compute gradients
                optimizer.step()

                # Weight the batch-mean loss by the batch size.
                loss_training_all += loss_training.item() * images.shape[0]
                correct_train += self._count_correct(outputs, target)

                # Update learning rate per step.
                if lr_scheduler is not None and args.sched_on_updates:
                    lr_scheduler.step_update(epoch * updates_per_epoch + float(step))

            # Update learning rate per epoch.
            if lr_scheduler is not None and not args.sched_on_updates:
                lr_scheduler.step(epoch + 1)

            loss_training_all = self._reduce_to_rank0(loss_training_all, device)
            correct_train = self._reduce_to_rank0(correct_train, device)

            loss_training_each = loss_training_all / len(self.train_loader.dataset)
            top1_avg_train = 100.0 * correct_train / len(self.train_loader.dataset)

            # ---- Test ----
            loss_test_all = 0
            correct = 0
            correct_top5 = 0

            model.eval()
            with torch.no_grad():
                for images, target in self.eval_loader:
                    images = images.to(device)
                    target = target.to(device)

                    outputs = model(images)
                    correct += self._count_correct(outputs, target)

                    loss_test = validate_loss_fn(outputs, target)
                    loss_test_all += loss_test.item() * outputs.shape[0]

                    if self.flag_acc5 == 1:
                        # NOTE(review): assumes utils.accuracy returns the number
                        # of correct samples, not a percentage -- confirm.
                        _, acc5 = accuracy(outputs, target, topk=(1, 5))
                        correct_top5 += acc5[0].item()

            loss_test_all = self._reduce_to_rank0(loss_test_all, device)
            correct = self._reduce_to_rank0(correct, device)
            if self.flag_acc5 == 1:
                correct_top5 = self._reduce_to_rank0(correct_top5, device)

            loss_test_each = loss_test_all / len(self.eval_loader.dataset)
            top1_avg = 100.0 * correct / len(self.eval_loader.dataset)
            top5_avg = 0.0
            if self.flag_acc5 == 1:
                top5_avg = 100.0 * correct_top5 / len(self.eval_loader.dataset)

            epoch_time = timeit.default_timer() - start_epoch_time

            # Show and record the results for this epoch.
            if flag_log == 1:
                if self.flag_acc5 == 1:
                    print('Epoch [{}/{}], Training Acc: {:.3f} %, Training Loss: {:.4f}, Test Acc: {:.3f} %, Test Acc5: {:.3f} %, Test Loss: {:.4f}, Epoch Time: {:.2f}s'.
                          format(epoch + 1, self.num_epochs, top1_avg_train, loss_training_each, top1_avg, top5_avg, loss_test_each, epoch_time))
                else:
                    print('Epoch [{}/{}], Training Acc: {:.3f} %, Training Loss: {:.4f}, Test Acc: {:.3f} %, Test Loss: {:.4f}, Epoch Time: {:.2f}s'.
                          format(epoch + 1, self.num_epochs, top1_avg_train, loss_training_each, top1_avg, loss_test_each, epoch_time))

                if self.flag_wandb == 1:
                    wandb.log({"epoch": epoch,
                               "loss_training": loss_training_each,
                               "test_acc": top1_avg,
                               "loss_test": loss_test_each,
                               "lr": optimizer.param_groups[0]["lr"]
                               })

                results[epoch][0] = loss_training_each
                results[epoch][1] = top1_avg
                results[epoch][2] = loss_test_each
                results[epoch][3] = epoch_time
                results[epoch][4] = top5_avg

        # Summary over all epochs.
        top1_avg_max = np.max(results[:, 1])
        top1_avg_max_index = np.argmax(results[:, 1])
        top5_avg_max = 0
        if self.flag_acc5 == 1:
            top5_avg_max = np.max(results[:, 4])
        loss_training_bestacc = results[top1_avg_max_index, 0]
        loss_test_bestacc = results[top1_avg_max_index, 2]

        end_time = timeit.default_timer()

        if flag_log == 1:
            print(' ran for %.4fm' % ((end_time - start_time) / 60.))

            if self.flag_acc5 == 1:
                # The file name carries the result; the content is a placeholder.
                os.makedirs('results/acc', exist_ok=True)
                np.savetxt('results/acc/data_%s_model_%s_flag_defaug_%s_seed_%s_besttop1_%s_besttop5_%s.txt'
                           % (self.n_data, self.n_model, self.flag_defaug, self.seed, top1_avg_max, top5_avg_max),
                           np.zeros(2))

        if flag_log == 1 and self.flag_wandb == 1:
            wandb.run.summary["best_accuracy"] = top1_avg_max
            if self.flag_acc5 == 1:
                wandb.run.summary["best_accuracy_top5"] = top5_avg_max
            wandb.run.summary["loss_training_bestacc"] = loss_training_bestacc
            wandb.run.summary["loss_test_bestacc"] = loss_test_bestacc

            wandb.finish()

